if(!require(Matrix)){
  install.packages("Matrix", dependencies = TRUE)
}

if(!require(MATA)){
  install.packages("MATA", dependencies = TRUE)
}

if(!require(emmeans)){
  install.packages("emmeans",  dependencies = TRUE)
}

if(!require(lmerTest)){
  install.packages("lmerTest", dependencies = TRUE)
}


# Marginal log-likelihood of a Gaussian linear mixed model with a single
# random intercept, evaluated at fixed variance components.
#
# Covariance: V = sigma2_wp * Z Z' + sigma2_sp * I_n. The fixed effects are
# replaced by their GLS estimate given V, and the full (not REML)
# log-likelihood is returned.
#
# Arguments:
#   var  numeric; var[1] = random-intercept (whole-plot) variance,
#        var[2] = residual (subplot) variance.
#   X    fixed-effects design matrix with length(y) rows.
#   Z    random-effects design matrix with length(y) rows.
#   y    numeric response vector.
# Returns a single numeric value (the log-likelihood).
log_lik_lmm <- function(var, X, Z, y) {
  n <- length(y)
  sigma2_wp <- var[1]
  sigma2_sp <- var[2]
  # Z %*% (sigma2_wp * I) %*% t(Z) == sigma2_wp * tcrossprod(Z)
  V <- sigma2_wp * tcrossprod(Z) + sigma2_sp * diag(n)
  Vinv <- solve(V)
  XtVinv <- crossprod(X, Vinv)
  beta <- solve(XtVinv %*% X, XtVinv %*% y)
  r <- y - X %*% beta
  # log|V| via determinant(): log(det(V)) overflows to Inf (or underflows to
  # -Inf) once det(V) ~ sigma2^n leaves double range, even for moderate n.
  ld <- determinant(V, logarithm = TRUE)
  log_det_V <- as.numeric(ld$modulus)
  if (ld$sign < 0) {
    # Same outcome as log() of a negative determinant: V is not a valid
    # covariance matrix (e.g. a negative moment estimate of sigma2_wp).
    log_det_V <- NaN
  }
  as.numeric(-0.5 * log_det_V - 0.5 * crossprod(r, Vinv %*% r) -
    0.5 * n * log(2 * pi))
}

# Coerce a grouping column to a factor. Character input is converted; any
# other non-factor type is rejected, because the model formulas in MA() would
# otherwise treat it as a numeric covariate (e.g. Error(wpID) in aov()).
.ma_as_factor <- function(x, arg) {
  if (is.character(x)) {
    x <- factor(x)
  }
  if (!is.factor(x)) {
    stop(sprintf("`%s` must be a factor (or character) column.", arg),
         call. = FALSE)
  }
  x
}

# Replicate indices used by MA() for the variance-component moment estimates.
#   wp: index of each observation's whole plot among the whole plots that
#       received the same whole-plot treatment (1, 2, ..., reps).
#   sp: running index of each observation among those sharing its whole-plot
#       treatment, subplot treatment and whole-plot replicate, i.e. the
#       subplot replicate within a whole plot.
.ma_replicate_ids <- function(Trt_wp, Trt_sp, wp_ID) {
  n <- length(wp_ID)
  rep_wp <- integer(n)
  for (lev in levels(Trt_wp)) {
    in_trt <- which(Trt_wp == lev)
    plots <- levels(droplevels(wp_ID[in_trt]))
    rep_wp[in_trt] <- match(as.character(wp_ID[in_trt]), plots)
  }
  rep_sp <- ave(seq_len(n), Trt_wp, Trt_sp, rep_wp, FUN = seq_along)
  list(wp = rep_wp, sp = rep_sp)
}

# Model-average the whole-plot pairwise contrasts of a split-plot model and
# of the corresponding factorial (lm) model with MATA-Wald intervals, using
# marginal-AIC model weights.
#
# Arguments (response, wholeplot, subplot and wpID are unquoted column names
# of `data`, captured with substitute()):
#   response   numeric response.
#   wholeplot  whole-plot treatment (factor or character, >= 2 levels).
#   subplot    subplot treatment (factor or character).
#   wpID       whole-plot identifier (factor or character); presumably each
#              level labels one physical whole plot -- TODO confirm.
#   data       data frame holding the columns above.
#   method     split-plot fit: "aov" (aov() with Error(wpID)) or "lmer"
#              (lmerTest::lmer() with a (1 | wpID) random intercept).
#   alpha      per-tail error rate for MATA::mata.wald(); every reported
#              interval uses confidence level 1 - 2 * alpha.
#
# Returns a named list: sp_method, FactorialContrast_wholeplot,
# SplitPlotContrast_wholeplot, MAContrast_wholeplot, AIC_weights
# (split-plot first, factorial second), Factorial_var,
# <method>_wholeplot_var, <method>_subplot_var, SplitPlotWidth, MAWidth, W.R.
MA <- function(response, wholeplot, subplot, wpID, data, method = "aov",
               alpha = 0.025) {
  method <- match.arg(method, c("aov", "lmer"))

  var_names <- c(deparse(substitute(response)),
                 deparse(substitute(wholeplot)),
                 deparse(substitute(subplot)),
                 deparse(substitute(wpID)))

  y <- eval(substitute(response), data)
  Trt_wp <- .ma_as_factor(eval(substitute(wholeplot), data), "wholeplot")
  Trt_sp <- .ma_as_factor(eval(substitute(subplot), data), "subplot")
  wp_ID <- .ma_as_factor(eval(substitute(wpID), data), "wpID")

  wp <- levels(Trt_wp)
  n1 <- length(wp)
  n2 <- nlevels(Trt_sp)
  n <- length(y)
  if (n1 < 2) {
    stop("`wholeplot` needs at least two levels for pairwise comparisons.",
         call. = FALSE)
  }
  n_pairs <- n1 * (n1 - 1) / 2

  # Number of subplot replicates within a whole plot (1 in a standard design);
  # used in the expected-mean-square divisor below.
  reps <- .ma_replicate_ids(Trt_wp, Trt_sp, wp_ID)
  r2 <- length(unique(reps$sp))

  # mata.wald's alpha is per tail, so the matching two-sided level is
  # 1 - 2 * alpha (0.95 for the default alpha = 0.025).
  level <- 1 - 2 * alpha

  # Factorial model ----
  formula_lm <- as.formula(paste0(var_names[1], "~", var_names[2], "*",
                                  var_names[3]))
  model_lm <- lm(formula_lm, data = data)
  anova_lm <- anova(model_lm)
  formula_comp_wp <- as.formula(paste0("pairwise~", var_names[2]))

  wp_comp_means_lm <- suppressMessages(confint(
    emmeans::emmeans(model_lm, specs = formula_comp_wp,
                     adjust = "none")$contrasts,
    level = level
  ))
  FactorialContrast_wholeplot <- wp_comp_means_lm

  # Marginal AIC of the factorial model. With V = sigma2 * I the GLS estimate
  # is OLS and log|V| = n * log(sigma2), so no n x n matrices are needed
  # (and det() cannot overflow).
  sigma2_red <- anova_lm[["Mean Sq"]][nrow(anova_lm)]
  X <- model.matrix(formula_lm, data = data)
  beta_hat_lm <- solve(crossprod(X), crossprod(X, y))
  r <- y - X %*% beta_hat_lm
  mar_ll_lm <- -0.5 * n * log(sigma2_red) - 0.5 * sum(r^2) / sigma2_red -
    0.5 * n * log(2 * pi)
  rho_lm <- as.numeric(Matrix::rankMatrix(X)) + 1
  mar_AIC_lm <- -2 * mar_ll_lm + 2 * rho_lm

  # Split-plot model ----
  Z <- model.matrix(~ 0 + wp_ID)
  if (method == "lmer") {
    formula_lmer <- as.formula(paste0(var_names[1], "~", var_names[2], "*",
                                      var_names[3], "+(1|", var_names[4], ")"))
    model_lmer <- lmerTest::lmer(formula_lmer, data)
    SplitPlotContrast_wholeplot <- lmerTest::difflsmeans(model_lmer,
                                                         var_names[2],
                                                         level = level)
    # difflsmeans columns: Estimate, Std. Error, df, ...
    sp_cols <- 1:3
    init <- c(as.numeric(summary(model_lmer)$varcor),
              summary(model_lmer)$sigma^2)
  } else {
    formula_aov <- as.formula(paste0(var_names[1], "~", var_names[2], "*",
                                     var_names[3], "+Error(", var_names[4],
                                     ")"))
    model_aov <- aov(formula_aov, data)
    anova_aov <- summary(model_aov)
    # emmeans re-evaluates the stored call for aovlist objects.
    attr(model_aov, "call")$formula <- formula_aov
    attr(model_aov, "call")$data <- data
    SplitPlotContrast_wholeplot <- suppressMessages(confint(
      emmeans::emmeans(model_aov, specs = formula_comp_wp,
                       adjust = "none")$contrasts,
      level = level
    ))
    # confint(emmeans) columns: contrast, estimate, SE, df, ...
    sp_cols <- 2:4

    # Method-of-moments variance components from the two error strata:
    # the residual MS of the within stratum estimates sigma2_sp, and
    # E[MS_wholeplot_error] = sigma2_sp + n2 * r2 * sigma2_wp.
    # NOTE(review): sigma2_aov_wp can be negative when the whole-plot MS is
    # below the residual MS; it is passed through unchanged.
    ms_within <- anova_aov[[2]][[1]][["Mean Sq"]]
    ms_wp <- anova_aov[[1]][[1]][["Mean Sq"]]
    sigma2_aov_sp <- ms_within[length(ms_within)]
    sigma2_aov_wp <- (ms_wp[length(ms_wp)] - sigma2_aov_sp) / (n2 * r2)
    init <- c(sigma2_aov_wp, sigma2_aov_sp)
  }
  mar_ll_sp <- log_lik_lmm(init, X, Z, y)
  mar_AIC_sp <- -2 * mar_ll_sp + 2 * (as.numeric(Matrix::rankMatrix(X)) + 2)

  # Model averaging with MATA-Wald ----
  AIClist <- c(mar_AIC_sp, mar_AIC_lm)
  delta <- AIClist - min(AIClist)
  model_weights <- exp(-0.5 * delta) / sum(exp(-0.5 * delta))

  sp_est <- SplitPlotContrast_wholeplot[[sp_cols[1]]]
  sp_se <- SplitPlotContrast_wholeplot[[sp_cols[2]]]
  sp_df <- SplitPlotContrast_wholeplot[[sp_cols[3]]]
  lm_est <- wp_comp_means_lm[[2]]
  lm_se <- wp_comp_means_lm[[3]]
  lm_df <- wp_comp_means_lm[[4]]

  MA_CI_wp_comp <- matrix(NA_real_, nrow = n_pairs, ncol = 2)
  for (i in seq_len(n_pairs)) {
    MA_CI_wp_comp[i, ] <- MATA::mata.wald(
      c(sp_est[i], lm_est[i]), c(sp_se[i], lm_se[i]), model_weights,
      mata.t = TRUE, residual.dfs = c(sp_df[i], lm_df[i]), alpha = alpha
    )
  }

  # Contrast labels in combn() order, which is the pairwise order used by
  # emmeans and difflsmeans.
  names_wp <- paste(var_names[2], wp, sep = "_")
  pairs <- combn(names_wp, 2)
  names_comp_wp <- paste(pairs[1, ], pairs[2, ], sep = " - ")

  MW_CI_wp_comp <- data.frame(names_comp_wp, lm_est, MA_CI_wp_comp)
  colnames(MW_CI_wp_comp) <- c("contrast", "estimate", "lower.CL", "upper.CL")

  # Interval widths and percentage reduction of MA vs split-plot width ----
  SplitPlotWidth <- SplitPlotContrast_wholeplot[6] -
    SplitPlotContrast_wholeplot[5]
  names(SplitPlotWidth) <- "WholePlotCIWidth"
  MAWidth <- MW_CI_wp_comp[4] - MW_CI_wp_comp[3]
  names(MAWidth) <- "WholePlotCIWidth"
  W.R <- 100 * (1 - MAWidth / SplitPlotWidth)
  names(W.R) <- "IntervalWidthReductionPct"

  output <- list(method, FactorialContrast_wholeplot,
                 SplitPlotContrast_wholeplot, MW_CI_wp_comp, model_weights,
                 sigma2_red, init[1], init[2], SplitPlotWidth, MAWidth, W.R)
  names(output) <- c("sp_method", "FactorialContrast_wholeplot",
                     "SplitPlotContrast_wholeplot", "MAContrast_wholeplot",
                     "AIC_weights", "Factorial_var",
                     paste0(method, "_wholeplot_var"),
                     paste0(method, "_subplot_var"),
                     "SplitPlotWidth", "MAWidth", "W.R")
  output
}